{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook for preparing and saving MOLECULAR graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import pickle\n",
    "import time\n",
    "import os\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download ZINC dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('molecules.zip'):\n",
    "    print('downloading..')\n",
    "    !curl https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1 -o molecules.zip -J -L -k\n",
    "    !unzip molecules.zip -d ../\n",
    "    # !tar -xvf molecules.zip -C ../\n",
    "else:\n",
    "    print('File already downloaded')\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to DGL format and save with pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir('../../') # go to root folder of the project\n",
    "print(os.getcwd())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from data.molecules import MoleculeDatasetDGL \n",
    "\n",
    "from data.data import LoadData\n",
    "from torch.utils.data import DataLoader\n",
    "from data.molecules import MoleculeDataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'ZINC'\n",
    "dataset = MoleculeDatasetDGL(DATASET_NAME) \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_histo_graphs(dataset, title):\n",
    "    # histogram of graph sizes\n",
    "    graph_sizes = []\n",
    "    for graph in dataset:\n",
    "        graph_sizes.append(graph[0].number_of_nodes())\n",
    "    plt.figure(1)\n",
    "    plt.hist(graph_sizes, bins=20)\n",
    "    plt.title(title)\n",
    "    plt.show()\n",
    "    graph_sizes = torch.Tensor(graph_sizes)\n",
    "    print('min/max :',graph_sizes.min().long().item(),graph_sizes.max().long().item())\n",
    "    \n",
    "plot_histo_graphs(dataset.train,'trainset')\n",
    "plot_histo_graphs(dataset.val,'valset')\n",
    "plot_histo_graphs(dataset.test,'testset')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(dataset.train))\n",
    "print(len(dataset.val))\n",
    "print(len(dataset.test))\n",
    "\n",
    "print(dataset.train[0])\n",
    "print(dataset.val[0])\n",
    "print(dataset.test[0])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_atom_type = 28\n",
    "num_bond_type = 4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "with open('data/molecules/ZINC.pkl','wb') as f:\n",
    "        pickle.dump([dataset.train,dataset.val,dataset.test,num_atom_type,num_bond_type],f)\n",
    "print('Time (sec):',time.time() - start)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test load function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'ZINC'\n",
    "dataset = LoadData(DATASET_NAME)\n",
    "trainset, valset, testset = dataset.train, dataset.val, dataset.test\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 10\n",
    "collate = MoleculeDataset.collate\n",
    "print(MoleculeDataset)\n",
    "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.chdir('./data/molecules/') # go to root folder of the project\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download AqSol dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isfile('aqsol_graph_raw.zip'):\n",
    "    print('downloading..')\n",
    "    !curl https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1 -o aqsol_graph_raw.zip -J -L -k\n",
    "    !unzip aqsol_graph_raw.zip -d ./\n",
    "else:\n",
    "    print('File already downloaded')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convert to DGL format and save with pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.chdir('../../') # go to root folder of the project\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'AqSol'\n",
    "dataset = MoleculeDatasetDGL(DATASET_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_histo_graphs(dataset.train,'trainset')\n",
    "plot_histo_graphs(dataset.val,'valset')\n",
    "plot_histo_graphs(dataset.test,'testset')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(dataset.train))\n",
    "print(len(dataset.val))\n",
    "print(len(dataset.test))\n",
    "\n",
    "print(dataset.train[0])\n",
    "print(dataset.val[0])\n",
    "print(dataset.test[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_atom_type = dataset.num_atom_type\n",
    "num_bond_type = dataset.num_bond_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "with open('data/molecules/AQSOL.pkl','wb') as f:\n",
    "        pickle.dump([dataset.train,dataset.val,dataset.test,num_atom_type,num_bond_type],f)\n",
    "print('Time (sec):',time.time() - start)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test load function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATASET_NAME = 'AQSOL'\n",
    "dataset = LoadData(DATASET_NAME)\n",
    "trainset, valset, testset = dataset.train, dataset.val, dataset.test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 10\n",
    "collate = MoleculeDataset.collate\n",
    "print(MoleculeDataset)\n",
    "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
